from runmodel_rules import *
import csv
import datetime
import pandas as pd


def add_csv(row, path='8_21_Rules_Fold_withLossm4567.csv'):
    """Append a single row to a CSV results file.

    Args:
        row: Sequence of values to write as one CSV record. Each value is
            converted to text by ``csv.writer``.
        path: File to append to. It is created if it does not exist. The
            default keeps the original hard-coded results filename.
    """
    # newline='' is required by the csv module so the writer's own
    # '\r\n' terminators are not translated again, which would produce
    # blank lines between rows on Windows.
    # The `with` block closes the file; no explicit close() is needed.
    with open(path, 'a', newline='') as csv_file:
        csv.writer(csv_file).writerow(row)

#
#
# We do not provide our corpus because it contains patient information.
# `ds` holds the text for each patient;
# it must be provided by the user before running this script.
# Imported once here rather than inside the innermost loop.
import matplotlib.pyplot as plt

# Model identifiers passed to CNN_call as `mod_num`.
models = [4, 5, 6, 7]

for mod_num in models:
    # Fold index passed to CNN_call.
    # NOTE(review): assumes CNN_call accepts fold_num 0-9; confirm.
    for fold_num in range(10):
        # Start below any achievable accuracy so the first run of each fold
        # is always saved, even if every run scores exactly 0.
        top_acc = float('-inf')
        # Train 10 times per (model, fold). Every run is logged to the CSV;
        # only the run with the highest accuracy so far gets its ROC plot
        # and model written to disk.
        for _ in range(10):
            (model, loss, acc, fpr, tpr, auc, test_preds, train_doc, test_doc,
             train_val, test_val, val_loss, train_loss) = CNN_call(ds, mod_num, fold_num)
            now = datetime.datetime.now()
            # NOTE(review): `ds` (the corpus text) is written into every CSV
            # row. This bloats the file and copies patient text into the
            # results. Confirm this is intended.
            row = [mod_num, fold_num, acc, loss, auc, fpr, tpr,
                   train_loss, val_loss, ds, str(now)]
            add_csv(row)

            if acc > top_acc:
                top_acc = acc
                # Used this link for help with ROC curve
                # https://stackoverflow.com/questions/25009284/how-to-plot-roc-curve-in-python
                plt.title('ROC')
                plt.plot(fpr, tpr, 'b', label='AUC = %0.2f, for accuracy %.2f' % (auc, acc))
                plt.legend(loc='lower right')
                # Diagonal reference line (chance level).
                plt.plot([0, 1], [0, 1], 'r--')
                plt.xlim([0, 1])
                plt.ylim([0, 1])
                plt.ylabel('True Positive Rate')
                plt.xlabel('False Positive Rate')
                stem = f'bestAcc_ROC_rules_model{mod_num}_fold{fold_num}'
                plt.savefig(stem + '.png')
                # Clear the figure so the next saved plot starts empty.
                plt.clf()
                model.save(stem + '.h5')
            